import os
import gzip
from pysam import VariantFile
from pylab import *
from scipy.stats import fisher_exact, rankdata
from scipy.special import bdtrc, chdtrc
from scipy.optimize import minimize


# Genome assembly name; passed to read_chromosome_sizes and
# read_refseq_chromosomes below to locate the per-assembly data files.
assembly = "hg38"


def read_significance():
    """Read signed significance scores from 'enhancers.deseq.txt'.

    The file is whitespace-separated.  Its header must start with
    ``enhancer`` followed by ``<t>_basemean``, ``<t>_log2fc`` and
    ``<t>_pvalue`` for t in 00hr, 01hr, 04hr, 12hr, 24hr, 96hr and all.
    Every data row must have exactly 22 fields.  Only the ``all_log2fc``
    and ``all_pvalue`` columns are used.

    Returns:
        dict mapping enhancer name to ``-log10(all_pvalue) * sign(all_log2fc)``.
        The score is positive for a positive fold change, negative for a
        negative one, and 0 when the fold change is exactly 0.
    """
    filename = 'enhancers.deseq.txt'
    print("Reading %s" % filename)
    timepoints = ("00hr", "01hr", "04hr", "12hr", "24hr", "96hr", "all")
    expected_header = ["enhancer"] + [
        "%s_%s" % (timepoint, field)
        for timepoint in timepoints
        for field in ("basemean", "log2fc", "pvalue")
    ]
    ppvalues = {}
    # The context manager closes the file even when a format assertion fails.
    with open(filename) as handle:
        words = next(handle).split()
        assert words[:22] == expected_header
        for line in handle:
            words = line.split()
            assert len(words) == 22
            name = words[0]
            log2fc = float(words[20])
            pvalue = float(words[21])
            ppvalues[name] = -log10(pvalue) * sign(log2fc)
    return ppvalues

def read_expression(ppvalues):
    """Read per-enhancer tag counts from 'enhancers.expression.txt'.

    The header must be ``enhancer`` followed by 17 HiSeq library columns and
    16 CAGE library columns, exactly as listed in ``expected_header``.  Every
    data row has 34 fields.  Each count cell is ``"<int>,<int>"``, and both
    integers are added to the enhancer's total.

    Args:
        ppvalues: mapping whose keys are the known enhancer names (as returned
            by ``read_significance``).  Every enhancer in the file must be a key.

    Returns:
        ``(names, cage, hiseq)``:
        - ``names``: enhancer names in file order.
        - ``cage``, ``hiseq``: arrays with the summed CAGE and HiSeq counts
          per enhancer.
    """
    filename = "enhancers.expression.txt"
    print("Reading", filename)
    expected_header = [
        'enhancer',
        'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
        'HiSeq_t01_r1', 'HiSeq_t01_r2',
        'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
        'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
        'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
        'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
        'CAGE_00_hr_A', 'CAGE_00_hr_C', 'CAGE_00_hr_G', 'CAGE_00_hr_H',
        'CAGE_01_hr_A', 'CAGE_01_hr_C', 'CAGE_01_hr_G',
        'CAGE_04_hr_C', 'CAGE_04_hr_E',
        'CAGE_12_hr_A', 'CAGE_12_hr_C',
        'CAGE_24_hr_C', 'CAGE_24_hr_E',
        'CAGE_96_hr_A', 'CAGE_96_hr_C', 'CAGE_96_hr_E',
    ]

    def summed_counts(cells):
        # Each cell holds two comma-separated counts; the two-way unpacking
        # rejects any cell that does not.
        total = 0
        for cell in cells:
            count1, count2 = cell.split(",")
            total += int(count1)
            total += int(count2)
        return total

    names = []
    hiseq = []
    cage = []
    # The context manager closes the file even when a format assertion fails.
    with open(filename) as handle:
        words = next(handle).split()
        assert words[:34] == expected_header
        for line in handle:
            words = line.split()
            assert len(words) == 34
            name = words[0]
            # Validates that the name contains exactly one "|".
            _dataset, _locus = name.split("|")
            assert name in ppvalues
            names.append(name)
            hiseq.append(summed_counts(words[1:18]))   # the 17 HiSeq columns
            cage.append(summed_counts(words[18:34]))   # the 16 CAGE columns
    return names, array(cage), array(hiseq)

def read_refseq_chromosomes(assembly, *, directory="/osc-fs_home/scratch/mdehoon/Data/UCSC/"):
    """Map UCSC chromosome names to RefSeq chromosome names for ``assembly``.

    Reads the gzip-compressed, whitespace-separated file
    ``<directory>/<assembly>/chromAlias.txt.gz``.  The first three columns
    are taken as (alias, UCSC chromosome, source), and only rows whose
    source is ``refseq`` are kept.

    Args:
        assembly: assembly name, used as a subdirectory (e.g. ``"hg38"``).
        directory: root directory containing one subdirectory per assembly.
            Defaults to the original hard-coded location.

    Returns:
        dict mapping UCSC chromosome name -> RefSeq chromosome name.
    """
    filename = "chromAlias.txt.gz"
    path = os.path.join(directory, assembly, filename)
    print("Reading", path)
    refseq_chromosomes = {}
    # The context manager closes the stream even if a malformed line raises.
    with gzip.open(path, "rt") as stream:
        for line in stream:
            words = line.split()
            if words[2] == "refseq":
                refseq_chromosome = words[0]
                ucsc_chromosome = words[1]
                refseq_chromosomes[ucsc_chromosome] = refseq_chromosome
    return refseq_chromosomes

def read_chromosome_sizes(assembly, *, directory="/osc-fs_home/scratch/mdehoon/Data/Genomes"):
    """Read chromosome sizes for ``assembly``, skipping ``*_alt`` contigs.

    Reads ``<directory>/<assembly>/<assembly>.chrom.sizes``.  Each line must
    contain exactly two whitespace-separated fields: name and size.

    Args:
        assembly: assembly name (e.g. ``"hg38"``).
        directory: root directory containing one subdirectory per assembly.
            Defaults to the original hard-coded location.

    Returns:
        dict mapping chromosome name -> size as an int.
    """
    filename = "%s.chrom.sizes" % assembly
    path = os.path.join(directory, assembly, filename)
    sizes = {}
    # The context manager closes the file even if a malformed line raises.
    with open(path) as handle:
        for line in handle:
            chromosome, size = line.split()
            if chromosome.endswith("_alt"):
                continue
            sizes[chromosome] = int(size)
    return sizes

def _unique_variant_regions(variants, contig):
    """Return (number of records, set of unique (start, stop) pairs) for ``variants``.

    Every record must lie on ``contig``.  The set is filled directly rather
    than via an intermediate list, so each region is held only once.
    """
    total = 0
    regions = set()
    for variant in variants:
        assert variant.chrom == contig
        regions.add((variant.start, variant.stop))
        total += 1
    return total, regions


def count_background(snps, gwas, ucsc_chromosomes, chromosome_aliases=None):
    """Count unique SNP and GWAS-SNP positions over the given chromosomes.

    Each UCSC chromosome name is translated to a RefSeq name with
    ``chromosome_aliases``; chromosomes without a translation are skipped.
    A chromosome that ``fetch`` rejects with ValueError in both files is
    skipped as well.  Being rejected by only one of the two files trips an
    assertion.  Positions are deduplicated as ``(start, stop)`` pairs, and
    every GWAS position must also be a SNP position.

    Args:
        snps, gwas: variant files supporting ``fetch(contig)`` (pysam
            ``VariantFile`` objects in this script).
        ucsc_chromosomes: iterable of UCSC chromosome names (a dict iterates
            over its keys).
        chromosome_aliases: mapping UCSC name -> RefSeq name.  Defaults to
            the module-level ``refseq_chromosomes`` mapping.

    Returns:
        ``(snpcount, gwascount)``: the totals of unique positions.
    """
    if chromosome_aliases is None:
        chromosome_aliases = refseq_chromosomes
    snpcount = 0
    gwascount = 0
    for ucsc_chromosome in ucsc_chromosomes:
        print("Counting SNPs and GWAS SNPs on %s" % ucsc_chromosome)
        try:
            refseq_chromosome = chromosome_aliases[ucsc_chromosome]
        except KeyError:
            continue
        try:
            snp_variants = snps.fetch(refseq_chromosome)
        except ValueError:
            snp_variants = None
        try:
            gwas_variants = gwas.fetch(refseq_chromosome)
        except ValueError:
            gwas_variants = None
        # Both files must either contain this contig or both lack it.
        if snp_variants is None:
            assert gwas_variants is None
            continue
        assert gwas_variants is not None
        total, snpregions = _unique_variant_regions(snp_variants, refseq_chromosome)
        print("Found %d snpregions" % total)
        print("Found %d unique snpregions" % len(snpregions))
        snpcount += len(snpregions)
        total, gwasregions = _unique_variant_regions(gwas_variants, refseq_chromosome)
        print("Found %d gwasregions" % total)
        print("Found %d unique gwasregions" % len(gwasregions))
        assert gwasregions.issubset(snpregions)
        gwascount += len(gwasregions)
    return snpcount, gwascount

def minus_loglikelihood(x, gwascounts, snpcounts, expression):
    """Negative Poisson log-likelihood of the linear-rate model, up to a constant.

    With ``(a, b) = x`` and ``rate = a + b*expression``, this returns
    ``sum(rate*snpcounts) - sum(gwascounts*log(rate))``.  That equals the
    negative log-likelihood of ``gwascounts ~ Poisson(snpcounts * rate)``
    once the terms that do not depend on ``(a, b)`` are dropped.
    """
    intercept, slope = x
    rate = intercept + slope * expression
    log_rate_term = sum(gwascounts * log(rate))
    exposure_term = sum(rate * snpcounts)
    return exposure_term - log_rate_term
 
def minus_gradient(x, gwascounts, snpcounts, expression):
    """Gradient of ``minus_loglikelihood`` with respect to ``x = (a, b)``.

    Returns a length-2 array ``[-dL/da, -dL/db]``, where
    ``dL/da = sum(g/rate) - sum(s)`` and
    ``dL/db = sum(g*expression/rate) - sum(expression*s)``.
    """
    intercept, slope = x
    rate = intercept + slope * expression
    partials = array([
        sum(gwascounts / rate) - sum(snpcounts),
        sum(gwascounts * expression / rate) - sum(expression * snpcounts),
    ])
    return -partials

def _fit_poisson_regression(gwascount, snpcount, expression, x0):
    """Fit ``rate = a + b*expression`` by maximum likelihood and test ``b = 0``.

    Minimizes ``minus_loglikelihood``, using ``minus_gradient`` as the
    Jacobian, from the starting point ``x0``.  The null model fixes
    ``b = 0``; its likelihood is maximized at
    ``a = sum(gwascount) / sum(snpcount)``.  The two fits are compared
    with a likelihood-ratio test on 1 degree of freedom.

    Returns:
        ``(a1, b1, pvalue)``: the fitted intercept, the fitted slope and the
        p-value of the test.
    """
    args = (gwascount, snpcount, expression)
    result = minimize(minus_loglikelihood, x0, args, jac=minus_gradient)
    a1, b1 = result.x
    loglikelihood_h1 = -minus_loglikelihood((a1, b1), *args)
    a0 = sum(gwascount) / sum(snpcount)
    loglikelihood_h0 = -minus_loglikelihood((a0, 0), *args)
    statistic = 2 * (loglikelihood_h1 - loglikelihood_h0)
    pvalue = chdtrc(1, statistic)
    return a1, b1, pvalue


def _moving_gwas_fraction(ranks, gwascount, snpcount, window):
    """Moving-window GWAS/SNP fraction over enhancers sorted by ``ranks``.

    Enhancers are ordered by ``argsort(ranks)``.  For every run of
    ``window`` consecutive enhancers:
    - x is their mean rank;
    - y is their summed GWAS count divided by their summed SNP count.

    Window sums come from cumulative sums, so the cost is linear in
    ``len(ranks)`` rather than ``len(ranks) * window``.  Requires
    ``1 <= window <= len(ranks)``.

    Returns:
        ``(x, y)`` arrays of length ``len(ranks) - window + 1``.
    """
    order = argsort(ranks)

    def window_sums(values):
        # With a leading 0, totals[i + window] - totals[i] equals
        # sum(values[order][i:i + window]).
        totals = concatenate(([0], cumsum(values[order])))
        return totals[window:] - totals[:-window]

    x = window_sums(ranks) / window
    y = window_sums(gwascount) / window_sums(snpcount)
    return x, y


def make_figure_expression_dependence(names, cage, hiseq, gwascount, snpcount, background_fraction, window=5000):
    """Plot GWAS-SNP fraction against enhancer expression rank (HiSeq and CAGE).

    Expression is converted to ranks over all enhancers.  Enhancers without
    any SNP are then dropped, because their GWAS fraction is undefined.  For
    each library type the function:
    - fits a Poisson regression of the GWAS fraction on rank and prints it;
    - plots a moving average of the fraction together with the fitted line.

    Args:
        names: enhancer names, in the same order as ``cage`` and ``hiseq``.
        cage, hiseq: per-enhancer expression values.
        gwascount, snpcount: dicts mapping enhancer name to the number of
            unique GWAS / SNP positions that overlap it.
        background_fraction: genome-wide GWAS/SNP ratio, drawn as reference.
        window: number of enhancers in each moving-average window.

    Raises:
        ValueError: if fewer than ``window`` enhancers contain SNPs.

    Writes figure_expression_gwas.png and figure_expression_gwas.svg.
    """
    cage = rankdata(cage)
    hiseq = rankdata(hiseq)
    indices = array([index for index, name in enumerate(names) if snpcount[name] > 0], int)
    if len(indices) < window:
        raise ValueError("only %d enhancers contain SNPs; the moving average needs at least %d"
                         % (len(indices), window))
    cage = cage[indices]
    hiseq = hiseq[indices]
    gwascount = array([gwascount[names[index]] for index in indices])
    snpcount = array([snpcount[names[index]] for index in indices])
    figure(figsize=(8.0, 4.0))
    x0 = (0.0002, 1e-9)  # starting guess for (intercept, slope)

    # Left panel: short capped RNAs (HiSeq single-end libraries), raw fraction.
    a1, b1, pvalue = _fit_poisson_regression(gwascount, snpcount, hiseq, x0)
    print("HiSeq: Poisson regression, intercept = %g, slope = %g, p-value = %g" % (a1, b1, pvalue))
    x, y = _moving_gwas_fraction(hiseq, gwascount, snpcount, window)
    # The moving-average x values are mean ranks.  Span the full rank range so
    # that neither the curve nor the fitted line is clipped.
    x_range = array([0, hiseq.max()])
    subplot(121)
    plot(x, y, color='red')
    plot(x_range, a1 + b1 * x_range, color='red', linestyle='--')
    plot(x_range, [background_fraction, background_fraction], color='black')
    text(0, background_fraction, "background", horizontalalignment='right', verticalalignment='center', fontsize=8)
    xlim(*x_range)
    ylim(0.00022, 0.00069)
    xticks(fontsize=8)
    yticks(fontsize=8)
    xlabel("Expression rank of enhancers sorted by short\ncapped RNA expression (single-end libraries)", fontsize=8)
    label = "Fraction of SNP loci associated with GWAS traits\n(moving average over %d enhancers)" % window
    ylabel(label, fontsize=8, labelpad=12)
    title("Short capped RNAs (single-end libraries)", color='red', fontsize=8)

    # Right panel: long capped RNAs (CAGE), shown as enrichment over background.
    a1, b1, pvalue = _fit_poisson_regression(gwascount, snpcount, cage, x0)
    print("CAGE: Poisson regression, intercept = %g, slope = %g, p-value = %g" % (a1, b1, pvalue))
    x, y = _moving_gwas_fraction(cage, gwascount, snpcount, window)
    x_range = array([0, cage.max()])
    ax = subplot(122)
    ax.yaxis.set_label_position("right")
    ax.yaxis.tick_right()
    plot(x, y / background_fraction, color='blue')
    plot(x_range, (a1 + b1 * x_range) / background_fraction, color='blue', linestyle='--')
    plot(x_range, [1.0, 1.0], color='black')
    xlim(*x_range)
    ylim(0.00022/background_fraction, 0.00069/background_fraction)
    xticks(fontsize=8)
    yticks(fontsize=8)
    xlabel("Expression rank of enhancers sorted by long\ncapped RNA expression (CAGE libraries)", fontsize=8)
    label = "GWAS SNP enrichment relative to background\n(moving average over %d enhancers)" % window
    ylabel(label, fontsize=8)
    title("Long capped RNAs (CAGE libraries)", color='blue', fontsize=8)

    subplots_adjust(left=0.15, bottom=0.15, wspace=0.10)
    filename = "figure_expression_gwas.png"
    print("Saving figure as %s" % filename)
    savefig(filename)
    filename = "figure_expression_gwas.svg"
    print("Saving figure as %s" % filename)
    savefig(filename)


# Chromosome sizes (UCSC names; *_alt contigs dropped) and the UCSC -> RefSeq
# chromosome-name mapping for the selected assembly.
sizes = read_chromosome_sizes(assembly)
refseq_chromosomes = read_refseq_chromosomes(assembly)
# Signed significance per enhancer, then enhancer names with their summed CAGE
# and HiSeq counts (read_expression asserts each name has a significance score).
ppvalues = read_significance()
names, cage, hiseq = read_expression(ppvalues)

# dbSNP variants; queried below by RefSeq chromosome name.
directory = "/osc-fs_home/scratch/mdehoon/Data/NCBI/dbSNP/"
filename = "GCF_000001405.38.bcf"
path = os.path.join(directory, filename)
print("Reading", path)
snps = VariantFile(path)

# GWAS-associated variants; their positions are expected to be a subset of the
# dbSNP positions (asserted in count_background and in the overlap loop).
path = "gwas.bcf"
print("Reading", path)
gwas = VariantFile(path)


# Background rate: unique GWAS positions / unique SNP positions, counted over
# every chromosome listed in the sizes file.
background_snp_count, background_gwas_count = count_background(snps, gwas, sizes)
print("Background: %d SNPs, %d GWAS SNPs" % (background_snp_count, background_gwas_count))
# NOTE(review): raises ZeroDivisionError if no SNP positions were counted.
background_fraction = background_gwas_count / background_snp_count
print("Background: GWAS / SNP ratio = %f" % background_fraction)

def _overlapping_regions(variant_file, contig, start, end):
    """Return the unique (start, stop) pairs fetched from ``variant_file``.

    The records are fetched for ``contig`` between ``start`` and ``end``,
    and every record must lie on ``contig``.
    """
    regions = set()
    for variant in variant_file.fetch(contig, start, end):
        assert variant.chrom == contig
        regions.add((variant.start, variant.stop))
    return regions


# For every enhancer, count the unique SNP and GWAS positions it overlaps.
print("Counting enhancer overlap")
snpcount = {}
gwascount = {}
for name in names:
    # Names are parsed as "<dataset>|<chromosome>:<start>-<end>".  Each two-way
    # unpacking raises ValueError for a name that does not split that way.
    _dataset, locus = name.split("|")
    chromosome, start_end = locus.split(":")
    refseq_chromosome = refseq_chromosomes[chromosome]
    start, end = (int(coordinate) for coordinate in start_end.split("-"))
    snpregions = _overlapping_regions(snps, refseq_chromosome, start, end)
    gwasregions = _overlapping_regions(gwas, refseq_chromosome, start, end)
    snpcount[name] = len(snpregions)
    gwascount[name] = len(gwasregions)
    # Every GWAS position must also be a dbSNP position.
    assert gwasregions.issubset(snpregions)

make_figure_expression_dependence(names, cage, hiseq, gwascount, snpcount, background_fraction)

# Re-key the expression arrays by enhancer name; the code below sorts enhancer
# names with hiseq.get / cage.get.
hiseq = dict(zip(names, hiseq))
cage = dict(zip(names, cage))

# NOTE(review): this unconditional raise stops the script here, so everything
# below (the short/long enrichment analysis and figure_enhancer_gwas.*) never
# runs.  It looks like a debugging stop; remove it to produce that figure.
raise Exception

# Classify enhancers by their signed significance score.  Scores above
# -log10(0.05) go to "short"; scores below the mirrored negative threshold go
# to "long".
enriched_names = {'short': [], 'long': []}
ppvalue_threshold = -log10(0.05)
for name, ppvalue in ppvalues.items():
    if ppvalue > ppvalue_threshold:
        enriched_names['short'].append(name)
    elif ppvalue < -ppvalue_threshold:
        enriched_names['long'].append(name)

n = len(enriched_names['short'])
m = len(enriched_names['long'])
assert n < m
enriched_names['short'].sort(key=hiseq.get, reverse=True)
enriched_names['long'].sort(key=cage.get, reverse=True)

# Summed overlap counts per class: GWAS positions and all SNP positions.
gwas_totals = {}
snp_totals = {}
for category, category_names in enriched_names.items():
    gwas_totals[category] = sum([gwascount[name] for name in category_names])
    snp_totals[category] = sum([snpcount[name] for name in category_names])

fractions = []
standard_errors = []
for category in ['short', 'long']:
    numerator = gwas_totals[category]
    denominator = snp_totals[category]
    fractions.append(numerator / denominator)
    # Relative standard error sqrt(1/k + 1/n): the delta-method formula for a
    # ratio of two independent Poisson counts.
    standard_errors.append(sqrt((1/numerator) + (1/denominator)))

fractions = array(fractions)
standard_errors = array(standard_errors)

# Fisher's exact test needs disjoint cells.  GWAS positions are a subset of
# the SNP positions (asserted while counting overlaps), so the second column
# holds the SNP positions that are NOT GWAS hits.
contingency = zeros((2, 2), int)
contingency[0, 0] = gwas_totals['long']
contingency[0, 1] = snp_totals['long'] - gwas_totals['long']
contingency[1, 0] = gwas_totals['short']
contingency[1, 1] = snp_totals['short'] - gwas_totals['short']
oddsratio, pvalue = fisher_exact(contingency)
print("short vs long: Fisher-exact pvalue = %.3f" % pvalue)

# Bar chart of the GWAS/SNP fraction for the two enhancer classes, with a
# secondary y axis showing enrichment over the genome-wide background.
f = figure(figsize=(3.5,3.5))

x = arange(len(fractions))
alpha = 0.2 + (1 - 0.2) * 0.2  # = 0.36
# standard_errors hold relative errors, so scale by the fractions for
# absolute error bars.
yerr = fractions * standard_errors
bar(x, fractions, color=['red', 'blue'], alpha=alpha, yerr=yerr)
labels = ['Significantly enriched\nin short capped RNA\n(single-end) libraries\n($N = %d$)' % n,
          'Significantly enriched\nin long capped RNA\n(CAGE) libraries\n($N = %d$)' % m]
xticks(x, labels, fontsize=8)
yticks(fontsize=8)
ylabel("Fraction of SNP loci\nassociated with GWAS traits", fontsize=8, labelpad=20)

ratio_labels = []
for category in ('short', 'long'):
    names = enriched_names[category]
    numerator = sum([gwascount[name] for name in names])
    denominator = sum([snpcount[name] for name in names])
    ratio_labels.append("%d /\n%d" % (numerator, denominator))
    # One-sided binomial test: P(X >= numerator) for
    # X ~ Binomial(denominator, background_fraction).
    pvalue = bdtrc(numerator-1, denominator, background_fraction)
    enrichment = (numerator/denominator) / background_fraction
    print("%d enhancers enriched for %s capped RNAs: enrichment = %f, binomial test, p = %.3g" % (len(names), category, enrichment, pvalue))

xmin, xmax = xlim()
plot([xmin, xmax], [background_fraction, background_fraction], 'k--')
text(xmin, background_fraction, "background ", color='k', fontsize=8, verticalalignment='center', horizontalalignment='right')

ymin, ymax = ylim()
ymax *= 1.10
ylim(0, ymax)
xlim(xmin, xmax)

ax_t = f.axes[0].secondary_xaxis('top')
ax_t.set_xticks([0,1],ratio_labels, fontsize=8)

ax2 = f.axes[0].twinx()
ax2.set_ylabel("GWAS SNP enrichment", fontsize=8)
ax2.xaxis.tick_top()

# Same vertical extent as the left axis, expressed as enrichment over background.
ymax /= background_fraction
ylim(0, ymax)
yticks(fontsize=8)


subplots_adjust(left=0.3,right=0.65,bottom=0.2,top=0.8)


filename = "figure_enhancer_gwas.png"
print("Saving figure to %s" % filename)
savefig(filename)

filename = "figure_enhancer_gwas.svg"
print("Saving figure to %s" % filename)
savefig(filename)
